-
Notifications
You must be signed in to change notification settings - Fork 12.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
AMDGPU: Define v_mfma_f32_{16x16x128|32x32x64}_f8f6f4 instructions #116723
AMDGPU: Define v_mfma_f32_{16x16x128|32x32x64}_f8f6f4 instructions #116723
Conversation
@llvm/pr-subscribers-clang-codegen @llvm/pr-subscribers-llvm-ir Author: Matt Arsenault (arsenm) ChangesThese use a new VOP3PX encoding for the v_mfma_scale_* instructions, I'm not sure the intrinsic should really expose op_sel (or any of the The op_sel syntax also seems extra horrible in this usage, especially with the Patch is 250.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116723.diff 23 Files Affected:
diff --git a/llvm/docs/AMDGPUUsage.rst b/llvm/docs/AMDGPUUsage.rst
index a25b6feddbeddc..4bd45e7d3856a6 100644
--- a/llvm/docs/AMDGPUUsage.rst
+++ b/llvm/docs/AMDGPUUsage.rst
@@ -1397,6 +1397,19 @@ The AMDGPU backend implements the following LLVM IR intrinsics.
used by hardware to control active lanes when used in EXEC register.
For example, ballot(i1 true) return EXEC mask.
+ llvm.amdgcn.mfma.f32.16x16x128.f8f6f4.scaled Emit `v_mfma_f32_16x16x128_f8f6f4`, bundled with a `v_mfma_ld_scale_b32`
+ to set the scale factor. The last 4 operands correspond to the inputs
+ to `v_mfma_ld_scale_b32`:
+
+ llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4 Emit `v_mfma_scale_f32_16x16x128_f8f6f4` to set the scale factor. The
+ last 4 operands correspond to the scale inputs.
+ 2-bit byte index to use for each lane for matrix A
+ Matrix A scale values
+ 2-bit byte index to use for each lane for matrix B
+ Matrix B scale values
+
+ llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4 Emit `v_mfma_scale_f32_32x32x64_f8f6f4`
+
============================================== ==========================================================
.. TODO::
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 360af786c5160d..596dc3c9656244 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2968,6 +2968,27 @@ class AMDGPUMfmaIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
[IntrConvergent, IntrNoMem,
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>]>;
+
+class AMDGPUMfmaScaleIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
+ ClangBuiltin<!subst("int", "__builtin", NAME)>,
+ DefaultAttrsIntrinsic<[DestTy],
+ [SrcABTy, SrcABTy, DestTy,
+ llvm_i32_ty, // cbsz
+ llvm_i32_ty, // abid
+ llvm_i32_ty, // blgp
+ // llvm_i1_ty, // TODO: neg_src2
+ // llvm_i1_ty, // TODO: abs_src2
+ // llvm_i1_ty, // TODO: clamp
+ llvm_i32_ty, // op_sel (A matrix scale, 2-bits) // TODO: Make i2?
+ llvm_i32_ty, // v_mfma_ld_scale_b32 src0 (A matrix scale)
+ llvm_i32_ty, // op_sel (B matrix scale, 2-bits) // TODO: Make i2?
+ llvm_i32_ty // v_mfma_ld_scale_b32 src1 (B matrix scale)
+ ],
+ [IntrConvergent, IntrNoMem,
+ ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>,
+ ImmArg<ArgIndex<6>>, ImmArg<ArgIndex<8>>
+ ]>;
+
defset list<Intrinsic> AMDGPUMFMAIntrinsics908 = {
def int_amdgcn_mfma_f32_32x32x1f32 : AMDGPUMfmaIntrinsic<llvm_v32f32_ty, llvm_float_ty>;
def int_amdgcn_mfma_f32_16x16x1f32 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_float_ty>;
@@ -3119,6 +3140,8 @@ def int_amdgcn_mfma_f32_16x16x32_f16 : AMDGPUMfmaIntrinsic<llvm_v4f32_ty, llvm_v
def int_amdgcn_mfma_f32_32x32x16_f16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8f16_ty>;
def int_amdgcn_mfma_f32_32x32x16_bf16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8bf16_ty>;
+def int_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v4f32_ty, llvm_v8i32_ty>;
+def int_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v16f32_ty, llvm_v8i32_ty>;
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
index d348f489d95dd3..88fa96bd049f29 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
@@ -423,3 +423,6 @@ def gi_fp_pow2_to_exponent : GICustomOperandRenderer<"renderFPPow2ToExponent">,
def gi_as_hw_round_mode : GICustomOperandRenderer<"renderRoundMode">,
GISDNodeXFormEquiv<as_hw_round_mode>;
+
+def gi_MFMALdScaleModifierOp : GICustomOperandRenderer<"renderScaledMAIIntrinsicOperand">,
+ GISDNodeXFormEquiv<MFMALdScaleXForm>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index 3522ece24f1c45..58125f7f0831c6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -5737,6 +5737,18 @@ void AMDGPUInstructionSelector::renderRoundMode(MachineInstrBuilder &MIB,
MIB.addImm((MI.getOperand(OpIdx).getImm() + 3) % 4);
}
+/// Convert from 2-bit value to enum values used for op_sel* source modifiers.
+void AMDGPUInstructionSelector::renderScaledMAIIntrinsicOperand(
+ MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const {
+ unsigned Val = MI.getOperand(OpIdx).getImm();
+ unsigned New = 0;
+ if (Val & 0x1)
+ New |= SISrcMods::OP_SEL_0;
+ if (Val & 0x2)
+ New |= SISrcMods::OP_SEL_1;
+ MIB.addImm(New);
+}
+
bool AMDGPUInstructionSelector::isInlineImmediate(const APInt &Imm) const {
return TII.isInlineConstant(Imm);
}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
index 42343104812b66..563e40267f04b1 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
@@ -364,6 +364,8 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
void renderRoundMode(MachineInstrBuilder &MIB, const MachineInstr &MI,
int OpIdx) const;
+ void renderScaledMAIIntrinsicOperand(MachineInstrBuilder &MIB,
+ const MachineInstr &MI, int OpIdx) const;
bool isInlineImmediate(const APInt &Imm) const;
bool isInlineImmediate(const APFloat &Imm) const;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructions.td b/llvm/lib/Target/AMDGPU/AMDGPUInstructions.td
index 671070c70f0c41..6a5065cd4a0e8f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructions.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructions.td
@@ -40,7 +40,7 @@ class AMDGPUInst <dag outs, dag ins, string asm = "",
// instructions to not match without killing the whole decode process. It is
// mainly used for ARM, but Tablegen expects this field to exist or it fails
// to build the decode table.
- field bits<96> SoftFail = 0;
+ field bits<128> SoftFail = 0; // FIXME: If this is smaller than largest instruction, DecodeEmitter crashes
let DecoderNamespace = Namespace;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index b648b68f3bd2b0..0467d26bd3093d 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -4769,6 +4769,25 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
break;
}
+ case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
+ case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
+ const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
+ OpdsMapping[0] =
+ Info->mayNeedAGPRs()
+ ? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
+ : getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
+
+ OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
+ OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
+ OpdsMapping[4] =
+ Info->mayNeedAGPRs()
+ ? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
+ : getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
+
+ OpdsMapping[9] = getVGPROpMapping(MI.getOperand(9).getReg(), MRI, *TRI);
+ OpdsMapping[11] = getVGPROpMapping(MI.getOperand(11).getReg(), MRI, *TRI);
+ break;
+ }
case Intrinsic::amdgcn_smfmac_f32_16x16x32_f16:
case Intrinsic::amdgcn_smfmac_f32_32x32x16_f16:
case Intrinsic::amdgcn_smfmac_f32_16x16x32_bf16:
diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
index 7c293c1a5e512a..37cff2b9e1c959 100644
--- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
+++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
@@ -493,6 +493,17 @@ static inline DecoderUInt128 eat12Bytes(ArrayRef<uint8_t> &Bytes) {
return DecoderUInt128(Lo, Hi);
}
+static inline DecoderUInt128 eat16Bytes(ArrayRef<uint8_t> &Bytes) {
+ assert(Bytes.size() >= 16);
+ uint64_t Lo =
+ support::endian::read<uint64_t, llvm::endianness::little>(Bytes.data());
+ Bytes = Bytes.slice(8);
+ uint64_t Hi =
+ support::endian::read<uint64_t, llvm::endianness::little>(Bytes.data());
+ Bytes = Bytes.slice(8);
+ return DecoderUInt128(Lo, Hi);
+}
+
DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
ArrayRef<uint8_t> Bytes_,
uint64_t Address,
@@ -529,6 +540,15 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
// Reinitialize Bytes
Bytes = Bytes_.slice(0, MaxInstBytesNum);
+
+ } else if (Bytes.size() >= 16 &&
+ STI.hasFeature(AMDGPU::FeatureGFX950Insts)) {
+ DecoderUInt128 DecW = eat16Bytes(Bytes);
+ if (tryDecodeInst(DecoderTableGFX940128, MI, DecW, Address, CS))
+ break;
+
+ // Reinitialize Bytes
+ Bytes = Bytes_.slice(0, MaxInstBytesNum);
}
if (Bytes.size() >= 8) {
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp
index 3c9f6d2938075b..56ed29ede02c23 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp
@@ -59,6 +59,10 @@ unsigned AMDGPUMCAsmInfo::getMaxInstLength(const MCSubtargetInfo *STI) const {
if (STI->hasFeature(AMDGPU::FeatureNSAEncoding))
return 20;
+ // VOP3PX encoding.
+ if (STI->hasFeature(AMDGPU::FeatureGFX950Insts))
+ return 16;
+
// 64-bit instruction with 32-bit literal.
if (STI->hasFeature(AMDGPU::FeatureVOP3Literal))
return 12;
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 5b02f9bf80d3fc..1e0dc7d1fd9bd6 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -15449,6 +15449,23 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
MRI.setRegClass(Op.getReg(), NewRC);
}
+ if (TII->isMAI(MI)) {
+ // The ordinary src0, src1, src2 were legalized above.
+ //
+ // We have to also legalize the appended v_mfma_ld_scale_b32 operands,
+ // as a separate instruction.
+ int Src0Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
+ AMDGPU::OpName::scale_src0);
+ if (Src0Idx != -1) {
+ int Src1Idx = Src0Idx + 2;
+ assert(Src1Idx = AMDGPU::getNamedOperandIdx(
+ MI.getOpcode(), AMDGPU::OpName::scale_src1));
+ if (TII->usesConstantBus(MRI, MI, Src0Idx) &&
+ TII->usesConstantBus(MRI, MI, Src1Idx))
+ TII->legalizeOpWithMove(MI, Src1Idx);
+ }
+ }
+
if (!HasAGPRs)
return;
diff --git a/llvm/lib/Target/AMDGPU/SIInstrFormats.td b/llvm/lib/Target/AMDGPU/SIInstrFormats.td
index dd1ab2c628715d..267c9a94b90968 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrFormats.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrFormats.td
@@ -300,6 +300,11 @@ class Enc96 {
int Size = 12;
}
+class Enc128 {
+ field bits<128> Inst;
+ int Size = 16;
+}
+
def CPolBit {
int GLC = 0;
int SLC = 1;
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.h b/llvm/lib/Target/AMDGPU/SIInstrInfo.h
index 1f7fff76d15210..e55418326a4bd0 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.h
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.h
@@ -1115,6 +1115,12 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo {
const MachineOperand &MO,
const MCOperandInfo &OpInfo) const;
+ bool usesConstantBus(const MachineRegisterInfo &MRI, const MachineInstr &MI,
+ int OpIdx) const {
+ return usesConstantBus(MRI, MI.getOperand(OpIdx),
+ MI.getDesc().operands()[OpIdx]);
+ }
+
/// Return true if this instruction has any modifiers.
/// e.g. src[012]_mod, omod, clamp.
bool hasModifiers(unsigned Opcode) const;
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index d2024cf915874d..08e29ab78cceaf 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -914,6 +914,16 @@ def fp16_zeros_high_16bits : PatLeaf<(f16 VGPR_32:$src), [{
return fp16SrcZerosHighBits(N->getOpcode());
}]>;
+def MFMALdScaleXForm : SDNodeXForm<timm, [{
+ unsigned Val = N->getZExtValue();
+ unsigned New = 0;
+ if (Val & 0x1)
+ New |= SISrcMods::OP_SEL_0;
+ if (Val & 0x2)
+ New |= SISrcMods::OP_SEL_1;
+ return CurDAG->getTargetConstant(New, SDLoc(N), MVT::i32);
+}]>;
+
def is_canonicalized : PatLeaf<(fAny srcvalue:$src), [{
const SITargetLowering &Lowering =
*static_cast<const SITargetLowering *>(getTargetLowering());
@@ -1515,6 +1525,10 @@ class PackedIntInputMods <PackedIntInputModsMatchClass matchClass> : InputMods <
def PackedF16InputMods : PackedFPInputMods<PackedF16InputModsMatchClass>;
def PackedI16InputMods : PackedIntInputMods<PackedI16InputModsMatchClass>;
+def MFMALdScaleModifierOp : TImmLeaf<i32, [{
+ return isUInt<2>(Imm);
+}], MFMALdScaleXForm>;
+
//===----------------------------------------------------------------------===//
// Complex patterns
//===----------------------------------------------------------------------===//
@@ -2851,6 +2865,8 @@ def VOP_V16F32_V2I32_V4I32_I32 : VOPProfile <[v16f32, v2i32, v4i32, i32]>;
def VOP_V4F32_V8F16_V8F16_V4F32 : VOPProfile <[v4f32, v8f16, v8f16, v4f32]>;
def VOP_V16F32_V8F16_V8F16_V16F32 : VOPProfile <[v16f32, v8f16, v8f16, v16f32]>;
def VOP_V16F32_V8BF16_V8BF16_V16F32 : VOPProfile <[v16f32, v8bf16, v8bf16, v16f32]>;
+def VOP_V4F32_V8I32_V8I32_V4F32 : VOPProfile <[v4f32, v8i32, v8i32, v4f32]>;
+def VOP_V16F32_V8I32_V8I32_V16F32 : VOPProfile <[v16f32, v8i32, v8i32, v16f32]>;
class Commutable_REV <string revOp, bit isOrig> {
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
index d47ff9fe96c94e..03d85b6081eada 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
@@ -1338,11 +1338,13 @@ class AVSrcOperand<RegisterClass regClass, string width>
def AVSrc_32 : AVSrcOperand<AV_32, "OPW32">;
def AVSrc_64 : AVSrcOperand<AV_64, "OPW64">;
def AVSrc_128 : AVSrcOperand<AV_128, "OPW128">;
+def AVSrc_256 : AVSrcOperand<AV_256, "OPW256">;
class AVDstOperand<RegisterClass regClass, string width>
: AVOperand<regClass, "decodeAV10", width>;
def AVDst_128 : AVDstOperand<AV_128, "OPW128">;
+def AVDst_256 : AVDstOperand<AV_256, "OPW256">;
def AVDst_512 : AVDstOperand<AV_512, "OPW512">;
class AVLdStOperand<RegisterClass regClass, string width>
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index 3a6202ea435222..9d39d3654c3ff0 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -639,14 +639,26 @@ def VOPProfileMAI_F32_V8F16_X16_VCD : VOPProfileMAI<VOP_V16F32_V8F16_V8F16_V16F3
def VOPProfileMAI_F32_V8BF16_X16 : VOPProfileMAI<VOP_V16F32_V8BF16_V8BF16_V16F32, AISrc_512_f32, ADst_512, AVSrc_128>;
def VOPProfileMAI_F32_V8BF16_X16_VCD : VOPProfileMAI<VOP_V16F32_V8BF16_V8BF16_V16F32, VISrc_512_f32, VDst_512, AVSrc_128>;
+// For f32_16x16x128_f8f6f4
+def VOPProfileMAI_F32_V8I32_X128 : VOPProfileMAI<VOP_V4F32_V8I32_V8I32_V4F32, AISrc_128_f32, ADst_128, AVSrc_256>;
+def VOPProfileMAI_F32_V8I32_X128_VCD : VOPProfileMAI<VOP_V4F32_V8I32_V8I32_V4F32, VISrc_128_f32, VDst_128, AVSrc_256>;
+
+// For f32_32x32x64_f8f6f4
+def VOPProfileMAI_F32_V8I32_X512 : VOPProfileMAI<VOP_V16F32_V8I32_V8I32_V16F32, AISrc_512_f32, ADst_512, AVSrc_256>;
+def VOPProfileMAI_F32_V8I32_X512_VCD : VOPProfileMAI<VOP_V16F32_V8I32_V8I32_V16F32, VISrc_512_f32, VDst_512, AVSrc_256>;
+
class MFMATable <bit is_mac, string Name> {
bit IsMac = is_mac;
string FMAOp = Name;
}
-class MAIFrag<SDPatternOperator Op, code pred> : PatFrag <
- (ops node:$src0, node:$src1, node:$src2, node:$cbsz, node:$abid, node:$blgp),
- (Op $src0, $src1, $src2, $cbsz, $abid, $blgp),
+class MAIFrag<SDPatternOperator Op, code pred, bit Scaled = false> : PatFrag <
+ !if(Scaled, (ops node:$src0, node:$src1, node:$src2, node:$cbsz, node:$abid, node:$blgp,
+ node:$scale_src0_opsel, node:$scale_src0,
+ node:$scale_src1_opsel, node:$scale_src1),
+ (ops node:$src0, node:$src1, node:$src2, node:$cbsz, node:$abid, node:$blgp)),
+ !if(Scaled, (Op $src0, $src1, $src2, $cbsz, $abid, $blgp, $scale_src0_opsel, $scale_src0, $scale_src1_opsel, $scale_src1),
+ (Op $src0, $src1, $src2, $cbsz, $abid, $blgp)),
pred
>;
@@ -666,11 +678,11 @@ defvar MayNotNeedAGPRs_gisel = [{
return !MF.getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs();
}];
-class AgprMAIFrag<SDPatternOperator Op> : MAIFrag<Op, MayNeedAGPRs> {
+class AgprMAIFrag<SDPatternOperator Op, bit Scaled = false> : MAIFrag<Op, MayNeedAGPRs, Scaled> {
let GISelPredicateCode = MayNeedAGPRs_gisel;
}
-class VgprMAIFrag<SDPatternOperator Op> : MAIFrag<Op, MayNotNeedAGPRs> {
+class VgprMAIFrag<SDPatternOperator Op, bit Scaled = false> : MAIFrag<Op, MayNotNeedAGPRs, Scaled> {
let GISelPredicateCode = MayNotNeedAGPRs_gisel;
}
@@ -683,26 +695,47 @@ let isAsCheapAsAMove = 1, isReMaterializable = 1 in {
} // End isMoveImm = 1
} // End isAsCheapAsAMove = 1, isReMaterializable = 1
-class MAIInst<string OpName, VOPProfile P, SDPatternOperator node>
- : VOP3InstBase<OpName, P, node> {
+class MAIInst<string OpName, VOPProfile P, SDPatternOperator node, bit Scaled = false>
+ : VOP3InstBase<OpName, P, node, /*IsVOP2=*/0, Scaled> {
Instruction Opcode = !cast<Instruction>(NAME);
bit is_dgemm = 0;
bit is_gfx940_xdl = 0;
}
-multiclass MAIInst<string OpName, string P, SDPatternOperator node> {
- defvar NoDstOverlap = !cast<VOPProfileMAI>("VOPProfileMAI_" # P).NoDstOverlap;
-
+// FIXME: Intrinsic should probably not have op_sel operands, we can
+// pattern match byte select patterns into op_sel.
+// FIXME: Missing neg and clamp modifiers
+//
+// FIXME: Usual syntax for op_sel is quite hostile here.
+class ScaledMAIInst<string OpName, MAIInst BaseInst, SDPatternOperator node> :
+ MAIInst<OpName, BaseInst.Pfl, node, /*Scaled=*/true> {
+ // Append operands from V_MFMA_LD_SCALE_B32, but we need to rename them.
+ let InOperandList = !con(BaseInst.InOperandList,
+ (ins VSrc_b32:$scale_src0,
+ VSrc_b32:$scale_src1,
+ op_sel0:$scale_src0_opsel,
+ op_sel_hi0:$scale_src1_opsel));
+ let AsmOperands =
+ "$vdst, $src0, $src1, $src2, $scale_src0, $scale_src1"
+ "$scale_src0_opsel$scale_src1_opsel$cbsz$abid$blgp";
+
+ let FixedSize = 1;
+ let Size = 16;
+}
+
+multiclass MAIInst<string OpName, string P, SDPatternOperator node = null_frag,
+ bit NoDstOverlap = !cast<VOPProfileMAI>("VOPProfileMAI_" # P).NoDstOverlap,
+ bit Scaled = false> {
let isConvergent = 1, mayRaiseFPException = 0, ReadsModeReg = 1 in {
// FP32 denorm mode is respected, rounding mode is not. Exceptions are not supported.
let Constraints = !if(NoDstOverlap, "@earlyclobber $vdst", "") in {
def _e64 : MAIInst<OpName, !cast<VOPProfileMAI>("VOPProfileMAI_" # P),
- !if(!or(NoDstOverlap, !eq(node, null_frag)), nul...
[truncated]
|
@llvm/pr-subscribers-mc Author: Matt Arsenault (arsenm) ChangesThese use a new VOP3PX encoding for the v_mfma_scale_* instructions, I'm not sure the intrinsic should really expose op_sel (or any of the The op_sel syntax also seems extra horrible in this usage, especially with the Patch is 250.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116723.diff 23 Files Affected:
diff --git a/llvm/docs/AMDGPUUsage.rst b/llvm/docs/AMDGPUUsage.rst
index a25b6feddbeddc..4bd45e7d3856a6 100644
--- a/llvm/docs/AMDGPUUsage.rst
+++ b/llvm/docs/AMDGPUUsage.rst
@@ -1397,6 +1397,19 @@ The AMDGPU backend implements the following LLVM IR intrinsics.
used by hardware to control active lanes when used in EXEC register.
For example, ballot(i1 true) return EXEC mask.
+ llvm.amdgcn.mfma.f32.16x16x128.f8f6f4.scaled Emit `v_mfma_f32_16x16x128_f8f6f4`, bundled with a `v_mfma_ld_scale_b32`
+ to set the scale factor. The last 4 operands correspond to the inputs
+ to `v_mfma_ld_scale_b32`:
+
+ llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4 Emit `v_mfma_scale_f32_16x16x128_f8f6f4` to set the scale factor. The
+ last 4 operands correspond to the scale inputs.
+ 2-bit byte index to use for each lane for matrix A
+ Matrix A scale values
+ 2-bit byte index to use for each lane for matrix B
+ Matrix B scale values
+
+ llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4 Emit `v_mfma_scale_f32_32x32x64_f8f6f4`
+
============================================== ==========================================================
.. TODO::
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 360af786c5160d..596dc3c9656244 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2968,6 +2968,27 @@ class AMDGPUMfmaIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
[IntrConvergent, IntrNoMem,
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>]>;
+
+class AMDGPUMfmaScaleIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
+ ClangBuiltin<!subst("int", "__builtin", NAME)>,
+ DefaultAttrsIntrinsic<[DestTy],
+ [SrcABTy, SrcABTy, DestTy,
+ llvm_i32_ty, // cbsz
+ llvm_i32_ty, // abid
+ llvm_i32_ty, // blgp
+ // llvm_i1_ty, // TODO: neg_src2
+ // llvm_i1_ty, // TODO: abs_src2
+ // llvm_i1_ty, // TODO: clamp
+ llvm_i32_ty, // op_sel (A matrix scale, 2-bits) // TODO: Make i2?
+ llvm_i32_ty, // v_mfma_ld_scale_b32 src0 (A matrix scale)
+ llvm_i32_ty, // op_sel (B matrix scale, 2-bits) // TODO: Make i2?
+ llvm_i32_ty // v_mfma_ld_scale_b32 src1 (B matrix scale)
+ ],
+ [IntrConvergent, IntrNoMem,
+ ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>,
+ ImmArg<ArgIndex<6>>, ImmArg<ArgIndex<8>>
+ ]>;
+
defset list<Intrinsic> AMDGPUMFMAIntrinsics908 = {
def int_amdgcn_mfma_f32_32x32x1f32 : AMDGPUMfmaIntrinsic<llvm_v32f32_ty, llvm_float_ty>;
def int_amdgcn_mfma_f32_16x16x1f32 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_float_ty>;
@@ -3119,6 +3140,8 @@ def int_amdgcn_mfma_f32_16x16x32_f16 : AMDGPUMfmaIntrinsic<llvm_v4f32_ty, llvm_v
def int_amdgcn_mfma_f32_32x32x16_f16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8f16_ty>;
def int_amdgcn_mfma_f32_32x32x16_bf16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8bf16_ty>;
+def int_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v4f32_ty, llvm_v8i32_ty>;
+def int_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v16f32_ty, llvm_v8i32_ty>;
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
index d348f489d95dd3..88fa96bd049f29 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
@@ -423,3 +423,6 @@ def gi_fp_pow2_to_exponent : GICustomOperandRenderer<"renderFPPow2ToExponent">,
def gi_as_hw_round_mode : GICustomOperandRenderer<"renderRoundMode">,
GISDNodeXFormEquiv<as_hw_round_mode>;
+
+def gi_MFMALdScaleModifierOp : GICustomOperandRenderer<"renderScaledMAIIntrinsicOperand">,
+ GISDNodeXFormEquiv<MFMALdScaleXForm>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index 3522ece24f1c45..58125f7f0831c6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -5737,6 +5737,18 @@ void AMDGPUInstructionSelector::renderRoundMode(MachineInstrBuilder &MIB,
MIB.addImm((MI.getOperand(OpIdx).getImm() + 3) % 4);
}
+/// Convert from 2-bit value to enum values used for op_sel* source modifiers.
+void AMDGPUInstructionSelector::renderScaledMAIIntrinsicOperand(
+ MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const {
+ unsigned Val = MI.getOperand(OpIdx).getImm();
+ unsigned New = 0;
+ if (Val & 0x1)
+ New |= SISrcMods::OP_SEL_0;
+ if (Val & 0x2)
+ New |= SISrcMods::OP_SEL_1;
+ MIB.addImm(New);
+}
+
bool AMDGPUInstructionSelector::isInlineImmediate(const APInt &Imm) const {
return TII.isInlineConstant(Imm);
}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
index 42343104812b66..563e40267f04b1 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
@@ -364,6 +364,8 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
void renderRoundMode(MachineInstrBuilder &MIB, const MachineInstr &MI,
int OpIdx) const;
+ void renderScaledMAIIntrinsicOperand(MachineInstrBuilder &MIB,
+ const MachineInstr &MI, int OpIdx) const;
bool isInlineImmediate(const APInt &Imm) const;
bool isInlineImmediate(const APFloat &Imm) const;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructions.td b/llvm/lib/Target/AMDGPU/AMDGPUInstructions.td
index 671070c70f0c41..6a5065cd4a0e8f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructions.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructions.td
@@ -40,7 +40,7 @@ class AMDGPUInst <dag outs, dag ins, string asm = "",
// instructions to not match without killing the whole decode process. It is
// mainly used for ARM, but Tablegen expects this field to exist or it fails
// to build the decode table.
- field bits<96> SoftFail = 0;
+ field bits<128> SoftFail = 0; // FIXME: If this is smaller than largest instruction, DecodeEmitter crashes
let DecoderNamespace = Namespace;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index b648b68f3bd2b0..0467d26bd3093d 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -4769,6 +4769,25 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
break;
}
+ case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
+ case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
+ const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
+ OpdsMapping[0] =
+ Info->mayNeedAGPRs()
+ ? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
+ : getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
+
+ OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
+ OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
+ OpdsMapping[4] =
+ Info->mayNeedAGPRs()
+ ? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
+ : getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
+
+ OpdsMapping[9] = getVGPROpMapping(MI.getOperand(9).getReg(), MRI, *TRI);
+ OpdsMapping[11] = getVGPROpMapping(MI.getOperand(11).getReg(), MRI, *TRI);
+ break;
+ }
case Intrinsic::amdgcn_smfmac_f32_16x16x32_f16:
case Intrinsic::amdgcn_smfmac_f32_32x32x16_f16:
case Intrinsic::amdgcn_smfmac_f32_16x16x32_bf16:
diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
index 7c293c1a5e512a..37cff2b9e1c959 100644
--- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
+++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
@@ -493,6 +493,17 @@ static inline DecoderUInt128 eat12Bytes(ArrayRef<uint8_t> &Bytes) {
return DecoderUInt128(Lo, Hi);
}
+static inline DecoderUInt128 eat16Bytes(ArrayRef<uint8_t> &Bytes) {
+ assert(Bytes.size() >= 16);
+ uint64_t Lo =
+ support::endian::read<uint64_t, llvm::endianness::little>(Bytes.data());
+ Bytes = Bytes.slice(8);
+ uint64_t Hi =
+ support::endian::read<uint64_t, llvm::endianness::little>(Bytes.data());
+ Bytes = Bytes.slice(8);
+ return DecoderUInt128(Lo, Hi);
+}
+
DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
ArrayRef<uint8_t> Bytes_,
uint64_t Address,
@@ -529,6 +540,15 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
// Reinitialize Bytes
Bytes = Bytes_.slice(0, MaxInstBytesNum);
+
+ } else if (Bytes.size() >= 16 &&
+ STI.hasFeature(AMDGPU::FeatureGFX950Insts)) {
+ DecoderUInt128 DecW = eat16Bytes(Bytes);
+ if (tryDecodeInst(DecoderTableGFX940128, MI, DecW, Address, CS))
+ break;
+
+ // Reinitialize Bytes
+ Bytes = Bytes_.slice(0, MaxInstBytesNum);
}
if (Bytes.size() >= 8) {
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp
index 3c9f6d2938075b..56ed29ede02c23 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp
@@ -59,6 +59,10 @@ unsigned AMDGPUMCAsmInfo::getMaxInstLength(const MCSubtargetInfo *STI) const {
if (STI->hasFeature(AMDGPU::FeatureNSAEncoding))
return 20;
+ // VOP3PX encoding.
+ if (STI->hasFeature(AMDGPU::FeatureGFX950Insts))
+ return 16;
+
// 64-bit instruction with 32-bit literal.
if (STI->hasFeature(AMDGPU::FeatureVOP3Literal))
return 12;
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 5b02f9bf80d3fc..1e0dc7d1fd9bd6 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -15449,6 +15449,23 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
MRI.setRegClass(Op.getReg(), NewRC);
}
+ if (TII->isMAI(MI)) {
+ // The ordinary src0, src1, src2 were legalized above.
+ //
+ // We have to also legalize the appended v_mfma_ld_scale_b32 operands,
+ // as a separate instruction.
+ int Src0Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
+ AMDGPU::OpName::scale_src0);
+ if (Src0Idx != -1) {
+ int Src1Idx = Src0Idx + 2;
+ assert(Src1Idx = AMDGPU::getNamedOperandIdx(
+ MI.getOpcode(), AMDGPU::OpName::scale_src1));
+ if (TII->usesConstantBus(MRI, MI, Src0Idx) &&
+ TII->usesConstantBus(MRI, MI, Src1Idx))
+ TII->legalizeOpWithMove(MI, Src1Idx);
+ }
+ }
+
if (!HasAGPRs)
return;
diff --git a/llvm/lib/Target/AMDGPU/SIInstrFormats.td b/llvm/lib/Target/AMDGPU/SIInstrFormats.td
index dd1ab2c628715d..267c9a94b90968 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrFormats.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrFormats.td
@@ -300,6 +300,11 @@ class Enc96 {
int Size = 12;
}
+class Enc128 {
+ field bits<128> Inst;
+ int Size = 16;
+}
+
def CPolBit {
int GLC = 0;
int SLC = 1;
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.h b/llvm/lib/Target/AMDGPU/SIInstrInfo.h
index 1f7fff76d15210..e55418326a4bd0 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.h
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.h
@@ -1115,6 +1115,12 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo {
const MachineOperand &MO,
const MCOperandInfo &OpInfo) const;
+ bool usesConstantBus(const MachineRegisterInfo &MRI, const MachineInstr &MI,
+ int OpIdx) const {
+ return usesConstantBus(MRI, MI.getOperand(OpIdx),
+ MI.getDesc().operands()[OpIdx]);
+ }
+
/// Return true if this instruction has any modifiers.
/// e.g. src[012]_mod, omod, clamp.
bool hasModifiers(unsigned Opcode) const;
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index d2024cf915874d..08e29ab78cceaf 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -914,6 +914,16 @@ def fp16_zeros_high_16bits : PatLeaf<(f16 VGPR_32:$src), [{
return fp16SrcZerosHighBits(N->getOpcode());
}]>;
+def MFMALdScaleXForm : SDNodeXForm<timm, [{
+ unsigned Val = N->getZExtValue();
+ unsigned New = 0;
+ if (Val & 0x1)
+ New |= SISrcMods::OP_SEL_0;
+ if (Val & 0x2)
+ New |= SISrcMods::OP_SEL_1;
+ return CurDAG->getTargetConstant(New, SDLoc(N), MVT::i32);
+}]>;
+
def is_canonicalized : PatLeaf<(fAny srcvalue:$src), [{
const SITargetLowering &Lowering =
*static_cast<const SITargetLowering *>(getTargetLowering());
@@ -1515,6 +1525,10 @@ class PackedIntInputMods <PackedIntInputModsMatchClass matchClass> : InputMods <
def PackedF16InputMods : PackedFPInputMods<PackedF16InputModsMatchClass>;
def PackedI16InputMods : PackedIntInputMods<PackedI16InputModsMatchClass>;
+def MFMALdScaleModifierOp : TImmLeaf<i32, [{
+ return isUInt<2>(Imm);
+}], MFMALdScaleXForm>;
+
//===----------------------------------------------------------------------===//
// Complex patterns
//===----------------------------------------------------------------------===//
@@ -2851,6 +2865,8 @@ def VOP_V16F32_V2I32_V4I32_I32 : VOPProfile <[v16f32, v2i32, v4i32, i32]>;
def VOP_V4F32_V8F16_V8F16_V4F32 : VOPProfile <[v4f32, v8f16, v8f16, v4f32]>;
def VOP_V16F32_V8F16_V8F16_V16F32 : VOPProfile <[v16f32, v8f16, v8f16, v16f32]>;
def VOP_V16F32_V8BF16_V8BF16_V16F32 : VOPProfile <[v16f32, v8bf16, v8bf16, v16f32]>;
+def VOP_V4F32_V8I32_V8I32_V4F32 : VOPProfile <[v4f32, v8i32, v8i32, v4f32]>;
+def VOP_V16F32_V8I32_V8I32_V16F32 : VOPProfile <[v16f32, v8i32, v8i32, v16f32]>;
class Commutable_REV <string revOp, bit isOrig> {
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
index d47ff9fe96c94e..03d85b6081eada 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
@@ -1338,11 +1338,13 @@ class AVSrcOperand<RegisterClass regClass, string width>
def AVSrc_32 : AVSrcOperand<AV_32, "OPW32">;
def AVSrc_64 : AVSrcOperand<AV_64, "OPW64">;
def AVSrc_128 : AVSrcOperand<AV_128, "OPW128">;
+def AVSrc_256 : AVSrcOperand<AV_256, "OPW256">;
class AVDstOperand<RegisterClass regClass, string width>
: AVOperand<regClass, "decodeAV10", width>;
def AVDst_128 : AVDstOperand<AV_128, "OPW128">;
+def AVDst_256 : AVDstOperand<AV_256, "OPW256">;
def AVDst_512 : AVDstOperand<AV_512, "OPW512">;
class AVLdStOperand<RegisterClass regClass, string width>
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index 3a6202ea435222..9d39d3654c3ff0 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -639,14 +639,26 @@ def VOPProfileMAI_F32_V8F16_X16_VCD : VOPProfileMAI<VOP_V16F32_V8F16_V8F16_V16F3
def VOPProfileMAI_F32_V8BF16_X16 : VOPProfileMAI<VOP_V16F32_V8BF16_V8BF16_V16F32, AISrc_512_f32, ADst_512, AVSrc_128>;
def VOPProfileMAI_F32_V8BF16_X16_VCD : VOPProfileMAI<VOP_V16F32_V8BF16_V8BF16_V16F32, VISrc_512_f32, VDst_512, AVSrc_128>;
+// For f32_16x16x128_f8f6f4
+def VOPProfileMAI_F32_V8I32_X128 : VOPProfileMAI<VOP_V4F32_V8I32_V8I32_V4F32, AISrc_128_f32, ADst_128, AVSrc_256>;
+def VOPProfileMAI_F32_V8I32_X128_VCD : VOPProfileMAI<VOP_V4F32_V8I32_V8I32_V4F32, VISrc_128_f32, VDst_128, AVSrc_256>;
+
+// For f32_32x32x64_f8f6f4
+def VOPProfileMAI_F32_V8I32_X512 : VOPProfileMAI<VOP_V16F32_V8I32_V8I32_V16F32, AISrc_512_f32, ADst_512, AVSrc_256>;
+def VOPProfileMAI_F32_V8I32_X512_VCD : VOPProfileMAI<VOP_V16F32_V8I32_V8I32_V16F32, VISrc_512_f32, VDst_512, AVSrc_256>;
+
class MFMATable <bit is_mac, string Name> {
bit IsMac = is_mac;
string FMAOp = Name;
}
-class MAIFrag<SDPatternOperator Op, code pred> : PatFrag <
- (ops node:$src0, node:$src1, node:$src2, node:$cbsz, node:$abid, node:$blgp),
- (Op $src0, $src1, $src2, $cbsz, $abid, $blgp),
+class MAIFrag<SDPatternOperator Op, code pred, bit Scaled = false> : PatFrag <
+ !if(Scaled, (ops node:$src0, node:$src1, node:$src2, node:$cbsz, node:$abid, node:$blgp,
+ node:$scale_src0_opsel, node:$scale_src0,
+ node:$scale_src1_opsel, node:$scale_src1),
+ (ops node:$src0, node:$src1, node:$src2, node:$cbsz, node:$abid, node:$blgp)),
+ !if(Scaled, (Op $src0, $src1, $src2, $cbsz, $abid, $blgp, $scale_src0_opsel, $scale_src0, $scale_src1_opsel, $scale_src1),
+ (Op $src0, $src1, $src2, $cbsz, $abid, $blgp)),
pred
>;
@@ -666,11 +678,11 @@ defvar MayNotNeedAGPRs_gisel = [{
return !MF.getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs();
}];
-class AgprMAIFrag<SDPatternOperator Op> : MAIFrag<Op, MayNeedAGPRs> {
+class AgprMAIFrag<SDPatternOperator Op, bit Scaled = false> : MAIFrag<Op, MayNeedAGPRs, Scaled> {
let GISelPredicateCode = MayNeedAGPRs_gisel;
}
-class VgprMAIFrag<SDPatternOperator Op> : MAIFrag<Op, MayNotNeedAGPRs> {
+class VgprMAIFrag<SDPatternOperator Op, bit Scaled = false> : MAIFrag<Op, MayNotNeedAGPRs, Scaled> {
let GISelPredicateCode = MayNotNeedAGPRs_gisel;
}
@@ -683,26 +695,47 @@ let isAsCheapAsAMove = 1, isReMaterializable = 1 in {
} // End isMoveImm = 1
} // End isAsCheapAsAMove = 1, isReMaterializable = 1
-class MAIInst<string OpName, VOPProfile P, SDPatternOperator node>
- : VOP3InstBase<OpName, P, node> {
+class MAIInst<string OpName, VOPProfile P, SDPatternOperator node, bit Scaled = false>
+ : VOP3InstBase<OpName, P, node, /*IsVOP2=*/0, Scaled> {
Instruction Opcode = !cast<Instruction>(NAME);
bit is_dgemm = 0;
bit is_gfx940_xdl = 0;
}
-multiclass MAIInst<string OpName, string P, SDPatternOperator node> {
- defvar NoDstOverlap = !cast<VOPProfileMAI>("VOPProfileMAI_" # P).NoDstOverlap;
-
+// FIXME: Intrinsic should probably not have op_sel operands, we can
+// pattern match byte select patterns into op_sel.
+// FIXME: Missing neg and clamp modifiers
+//
+// FIXME: Usual syntax for op_sel is quite hostile here.
+class ScaledMAIInst<string OpName, MAIInst BaseInst, SDPatternOperator node> :
+ MAIInst<OpName, BaseInst.Pfl, node, /*Scaled=*/true> {
+ // Append operands from V_MFMA_LD_SCALE_B32, but we need to rename them.
+ let InOperandList = !con(BaseInst.InOperandList,
+ (ins VSrc_b32:$scale_src0,
+ VSrc_b32:$scale_src1,
+ op_sel0:$scale_src0_opsel,
+ op_sel_hi0:$scale_src1_opsel));
+ let AsmOperands =
+ "$vdst, $src0, $src1, $src2, $scale_src0, $scale_src1"
+ "$scale_src0_opsel$scale_src1_opsel$cbsz$abid$blgp";
+
+ let FixedSize = 1;
+ let Size = 16;
+}
+
+multiclass MAIInst<string OpName, string P, SDPatternOperator node = null_frag,
+ bit NoDstOverlap = !cast<VOPProfileMAI>("VOPProfileMAI_" # P).NoDstOverlap,
+ bit Scaled = false> {
let isConvergent = 1, mayRaiseFPException = 0, ReadsModeReg = 1 in {
// FP32 denorm mode is respected, rounding mode is not. Exceptions are not supported.
let Constraints = !if(NoDstOverlap, "@earlyclobber $vdst", "") in {
def _e64 : MAIInst<OpName, !cast<VOPProfileMAI>("VOPProfileMAI_" # P),
- !if(!or(NoDstOverlap, !eq(node, null_frag)), nul...
[truncated]
|
These use a new VOP3PX encoding for the v_mfma_scale_* instructions,
which bundles the pre-scale v_mfma_ld_scale_b32. None of the modifiers
are supported yet (op_sel, neg or clamp).
I'm not sure the intrinsic should really expose op_sel (or any of the
others). If I'm reading the documentation correctly, we should be able
to just have the raw scale operands and auto-match op_sel to byte
extract patterns.
The op_sel syntax also seems extra horrible in this usage, especially with the
usual assumed op_sel_hi=-1 behavior.